import json
import random
import os
import numpy as np
import datetime
from scipy import stats
import tensorflow as tf

from config.caernn_config import CAERNNCongfig
from config.clvrnn_config import CLVRNNCongfig
from config.cvae_config import CVAECongfig
from config.cvrnn_config import CVRNNCongfig
from config.lstm_prediction_config import LSTMPredictConfig
from config.multi_agent_config import MultiAgentEmbedConfig
from config.stats_encoder_config import EncoderConfig
from support.embedding_tools import plot_embeddings, dimensional_reduction, aggregate_positions_within_cluster, \
    get_player_cluster, get_features_cluster
from support.model_tools import get_model_and_log_name, get_data_name, validate_games_player_id, \
    validate_games_embedding


def illustrate_embeddings(encoder_values, player_index_list,
                          player_basic_info_dir, model_msg,
                          location_features_all, action_features_all,
                          manpower_features_all, period_feature_all,
                          cluster_type='position'):
    """Group embeddings by ``cluster_type``, reduce them with t-SNE and plot them.

    The clustering helper (``get_player_cluster`` or ``get_features_cluster``)
    returns the cluster labels together with the (possibly filtered) embedding
    values; those returned values are what gets reduced and plotted.

    Args:
        encoder_values: embedding vectors to visualise.
        player_index_list: player indices aligned with ``encoder_values``;
            only used for the 'name' and 'position' cluster types.
        player_basic_info_dir: path to the player info file passed to
            ``get_player_cluster``.
        model_msg: prefix of the plot name; ``"_" + cluster_type`` is appended.
        location_features_all: features used for 'od-zone' / 'locations'.
        action_features_all: features used for 'action-o' / 'action-n' / 'action-d'.
        manpower_features_all: features used for 'manpower'.
        period_feature_all: features used for 'period'.
        cluster_type: one of 'name', 'position', 'od-zone', 'locations',
            'manpower', 'period', 'action-o', 'action-n', 'action-d'.

    Raises:
        ValueError: if ``cluster_type`` is not one of the supported values.
    """
    # Per-event feature clusters: identical t-SNE settings, only the source
    # feature array differs.
    feature_clusters = {
        'od-zone': location_features_all,
        'locations': location_features_all,
        'manpower': manpower_features_all,
        'period': period_feature_all,
    }
    # Action groups: maps raw action name -> legend label.
    action_groups = {
        'action-o': {'shot': 'Shot', 'pressure': 'Pressure',
                     'dumpin': 'Dumpin', 'check': 'Check',
                     'assist': 'Assist'},
        'action-n': {'pass': 'Pass', 'reception': 'Reception',
                     'carry': 'Carry', 'faceoff': 'Face-off',
                     'penalty': 'Penalty'},
        'action-d': {'lpr': 'LPR',
                     'block': 'Block',
                     'puckprotection': 'PP',
                     'dumpout': 'Dumpout',
                     'receptionprevention': 'RP'},
    }

    if cluster_type == 'name':
        perplexity = 10
        size = 5
        # Selected players (full name -> legend label). Commented entries are
        # alternative candidates that can be toggled in.
        player_name_selected = {
                                # 'Jack Eichel', # center
                                # 'Brent Burns', # defense
                                'Rasmus Dahlin': 'R.Dahlin', # defense
                                'Rasmus Ristolainen': 'R.Ristolainen', # defense
                                'Justin Braun': 'J.Braun', # defense
                                # 'Logan Couture', # center
                                'Zach Bogosian': 'Z.Bogosian', # defense
                                # 'Sam Reinhart', # Centre/Right Wing
                                # 'Marc-Edouard Vlasic', # defense
                                'Jake McCabe': 'J.McCabe', # defense
                                # 'Tomas Hertl', # center
                                # 'Erik Karlsson', # defense
                                # 'Brenden Dillon', # defense
                                # 'Marco Scandella', # defense
                                # 'Kevin Labanc', # Right Wing
                                # 'Dylan Larkin', # Right Wing
                                # 'Joe Pavelski', # Center/Right Wing
                                # 'Timo Meier', # Right Wing
                                # 'Evan Rodrigues' # Left Wing
                                # 'Nathan Beaulieu' # defense
        }
        cluster_list, player_cluster_mapping, encoder_values = get_player_cluster(
            player_index_list=player_index_list,
            player_basic_info_dir=player_basic_info_dir,
            cluster_type=cluster_type,
            all_encoder_values=encoder_values,
            cluster_selected=player_name_selected.keys(),
        )
        legend_cluster_mapping = player_name_selected
    elif cluster_type == 'position':
        perplexity = 30
        size = 2
        legend_cluster_mapping = None
        cluster_list, player_cluster_mapping, encoder_values = get_player_cluster(
            player_index_list=player_index_list,
            player_basic_info_dir=player_basic_info_dir,
            cluster_type=cluster_type,
            all_encoder_values=encoder_values)
    elif cluster_type in feature_clusters:
        perplexity = 30
        size = 0.1
        legend_cluster_mapping = None
        cluster_list, player_cluster_mapping, encoder_values = get_features_cluster(
            game_features_all=feature_clusters[cluster_type],
            cluster_type=cluster_type,
            all_encoder_values=encoder_values)
    elif cluster_type in action_groups:
        # Exact lookup: previously any string containing 'action' entered this
        # branch and crashed with UnboundLocalError when it was not o/n/d.
        perplexity = 50
        size = 2
        action_selected = action_groups[cluster_type]
        legend_cluster_mapping = action_selected
        cluster_list, player_cluster_mapping, encoder_values = get_features_cluster(
            game_features_all=action_features_all,
            cluster_type=cluster_type,
            all_encoder_values=encoder_values,
            cluster_selected=action_selected.keys(),
        )
    else:
        raise ValueError("Unknown cluster type {0}".format(cluster_type))
    # if 'clvrnn' in model_msg:
    #     [action_encoder_values, state_encoder_values] = np.split(encoder_values, [32], axis=1)
    #
    #     dr_embedding_action = dimensional_reduction(action_encoder_values, dr_method='TSNE')
    #     plot_embeddings(data=dr_embedding_action, cluster_number=cluster_list,
    #                     player_cluster_mapping=player_cluster_mapping, model_msg=model_msg+'_action'+'_'+cluster_type)
    #
    #     dr_embedding_state = dimensional_reduction(state_encoder_values, dr_method='TSNE')
    #     plot_embeddings(data=dr_embedding_state, cluster_number=cluster_list,
    #                     player_cluster_mapping=player_cluster_mapping, model_msg=model_msg+'_state'+'_'+cluster_type)

    dr_embedding = dimensional_reduction(encoder_values, dr_method='TSNE', perplexity=perplexity)
    plot_embeddings(data=dr_embedding, cluster_number=cluster_list,
                    player_cluster_mapping=player_cluster_mapping,
                    legend_cluster_mapping=legend_cluster_mapping,
                    size=size,
                    model_msg=model_msg + "_" + cluster_type)

def significance_test(testing_target, testing_objects, max_dims=128):
    """Run paired t-tests comparing one embedding matrix against others.

    For every array in ``testing_objects``, its first ``max_dims`` columns and
    those of ``testing_target`` are flattened and compared with
    ``scipy.stats.ttest_rel``. Each result is printed.

    Args:
        testing_target: array sliced as ``[:, :max_dims]`` (so at least 2-D);
            the reference sample.
        testing_objects: iterable of arrays to compare against the target;
            after slicing and flattening, each must match the target's length,
            since ``ttest_rel`` requires paired samples.
        max_dims: number of leading columns to compare (default 128, the
            previously hard-coded value).

    Returns:
        list of ``ttest_rel`` results, in the order of ``testing_objects``.
    """
    # The target slice does not depend on the loop variable; compute it once.
    target_sample = testing_target[:, :max_dims].flatten()
    results_all = []
    for testing_object in testing_objects:
        results = stats.ttest_rel(target_sample, testing_object[:, :max_dims].flatten())
        print(results)
        results_all.append(results)
    return results_all


if __name__ == '__main__':
    local_test_flag = False
    model_type_all = ['multi_agent']
    model_number = 1801
    player_info = '_pid'
    # Debug limit: only the first few testing games are embedded and plotted.
    max_testing_games = 5
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    player_basic_info_dir = '../sport_resource/ice_hockey_201819/player_info_2018_2019.json'
    game_date_dir = '../sport_resource/ice_hockey_201819/game_dates_2018_2019.json'
    player_box_score_dir = '../sport_resource/ice_hockey_201819/Scale_NHL_players_game_summary_201819.csv'
    # Every model type below uses the same local player-id mapping.
    player_id_cluster_dir = '../sport_resource/ice_hockey_201819/local_player_id_2018_2019.json'

    all_type_embedding = []

    for model_type in model_type_all:
        # Reset per model: only cvrnn sets a date suffix, and it must not leak
        # into the saved-network lookup of model types processed after it.
        date_msg = ''
        # Set for every model type (the multi_agent branch previously never
        # defined it, which crashed the illustrate_embeddings call below).
        model_msg = "{0}_{1}".format(model_type, model_number)

        if "clvrnn" in model_type:
            embed_mode = '_embed_random_v2'
            predicted_target = '_PlayerLocalId_predict_nex_goal'  # playerId_
            icehockey_clvrnn_config_path = "../environment_settings/icehockey_clvrnn{0}_config{1}{2}.yaml". \
                format(predicted_target, player_info, embed_mode)
            icehockey_model_config = CLVRNNCongfig.load(icehockey_clvrnn_config_path)
        elif model_type == 'caernn':
            predicted_target = '_PlayerLocalId_predict_nex_goal'  # playerId_
            icehockey_config_path = "../environment_settings/icehockey_caernn{0}_config{1}.yaml". \
                format(predicted_target, player_info)
            icehockey_model_config = CAERNNCongfig.load(icehockey_config_path)
        elif model_type == 'cvrnn':
            embed_mode = '_embed_random'
            predicted_target = '_PlayerLocalId_predict_nex_goal'
            skip_decoder_condition = True
            skip_msg = '_NoDecoderConditionZ2' if skip_decoder_condition else ''
            date_msg = '_' + '2020-05-03-00' + skip_msg
            icehockey_cvrnn_config_path = "../environment_settings/" \
                                          "icehockey_cvrnn{0}_config{1}{2}{3}.yaml".format(predicted_target,
                                                                                           player_info,
                                                                                           embed_mode,
                                                                                           skip_msg)
            icehockey_model_config = CVRNNCongfig.load(icehockey_cvrnn_config_path)
        elif model_type == 'cvae':
            predicted_target = '_PlayerLocalId_predict_next_goal'  # playerId_
            icehockey_config_path = "../environment_settings/icehockey_cvae_lstm{0}_config{1}.yaml".format(
                predicted_target, player_info)
            icehockey_model_config = CVAECongfig.load(icehockey_config_path)
        elif model_type == 'vhe':
            predicted_target = '_PlayerLocalId_predict_next_goal'  # playerId_
            icehockey_config_path = "../environment_settings/icehockey_vhe_lstm{0}_config{1}.yaml".format(
                predicted_target, player_info)
            icehockey_model_config = CVAECongfig.load(icehockey_config_path)
        elif model_type == 'encoder':
            predicted_target = '_PlayerLocalId_predict_next_goal'
            icehockey_encoder_config_path = "../environment_settings/" \
                                            "icehockey_stats_lstm_encoder{0}" \
                                            "_config{1}.yaml".format(predicted_target, player_info)
            icehockey_model_config = EncoderConfig.load(icehockey_encoder_config_path)
        elif model_type == 'multi_agent':
            icehockey_config_path = "../environment_settings/ice_hockey_multi_agent{0}.yaml".format(player_info)
            icehockey_model_config = MultiAgentEmbedConfig.load(icehockey_config_path)
        else:
            raise ValueError("uknown model catagoery {0}".format(model_type))

        saved_network_dir, log_dir = get_model_and_log_name(config=icehockey_model_config,
                                                            model_catagoery=model_type,
                                                            running_number=0,
                                                            date_msg=date_msg)

        print(model_type + '_' + str(model_number) + player_info)

        if local_test_flag:
            data_store_dir = "/Users/liu/Desktop/Ice-hokcey-data-sample/feature-sample"
        else:
            data_store_dir = icehockey_model_config.Learn.save_mother_dir \
                             + "/oschulte/Galen/Ice-hockey-data/2018-2019/"

        # One game id per line; blank lines (e.g. a trailing newline) are skipped
        # instead of crashing int().
        testing_file_path = os.path.join(saved_network_dir, 'testing_file_dirs_all.csv')
        with open(testing_file_path, 'r') as f:
            testing_dir_games_all = [str(int(line)) for line in f if line.strip()]
        testing_dir_games_all = testing_dir_games_all[:max_testing_games]

        (all_embedding,
         all_player_index,
         location_features_all,
         action_features_all,
         manpower_features_all,
         period_feature_all) = validate_games_embedding(config=icehockey_model_config,
                                                        data_store_dir=data_store_dir,
                                                        dir_all=testing_dir_games_all,
                                                        player_basic_info_dir=player_basic_info_dir,
                                                        game_date_dir=game_date_dir,
                                                        player_box_score_dir=player_box_score_dir,
                                                        model_number=model_number,
                                                        player_id_cluster_dir=player_id_cluster_dir,
                                                        saved_network_dir=saved_network_dir,
                                                        model_category=model_type,
                                                        source_data_path='/Local-Scratch/oschulte/Galen/2018-2019/',
                                                        )
        all_type_embedding.append(all_embedding)

        illustrate_embeddings(encoder_values=all_embedding,
                              player_index_list=all_player_index,
                              player_basic_info_dir=player_basic_info_dir,
                              model_msg=model_msg,
                              location_features_all=location_features_all,
                              action_features_all=action_features_all,
                              manpower_features_all=manpower_features_all,
                              period_feature_all=period_feature_all,
                              cluster_type='position'
                              )

    # significance_test(testing_target=all_type_embedding[0], testing_objects=all_type_embedding[1:])
